from tflite_support.metadata_writers import image_segmenter
from tflite_support.metadata_writers import writer_utils
from tflite_support.metadata_writers import metadata_info
from tflite_support import metadata_schema_py_generated as _metadata_fb


def save_model_with_metadata(tflite_model_path, save_path=None):
    """Build default image-segmenter metadata for a TFLite model.

    Uses ``image_segmenter.MetadataWriter.create_for_inference`` with an input
    normalization of mean 0.0 / std 255.0 and no label files, then prints the
    generated metadata JSON so it can be inspected.

    Args:
        tflite_model_path: Path to the source ``.tflite`` model.
        save_path: Optional destination path. When given, the model with the
            metadata populated into it is written there. When ``None`` (the
            default, matching the previous behavior), nothing is written.
    """
    ImageSegmenterWriter = image_segmenter.MetadataWriter

    norm_mean = 0.0
    norm_std = 255.0

    writer = ImageSegmenterWriter.create_for_inference(
        writer_utils.load_file(tflite_model_path),
        [norm_mean],
        [norm_std],
        [],  # no label files
    )

    # Verify the metadata generated by metadata writer.
    print(writer.get_metadata_json())

    # The function name promises a save, but the populated model used to be
    # discarded. Write it out when the caller supplies a destination.
    if save_path is not None:
        writer_utils.save_file(writer.populate(), save_path)


def save_model_with_metadata_advanced(tflite_model_path, save_path,
                                      label_file_path=None):
    """Attach hand-written image-segmenter metadata to a TFLite model and save it.

    Builds general, input and output metadata with ``metadata_info``, prints
    the resulting metadata JSON, then writes the populated model to
    ``save_path``.

    Args:
        tflite_model_path: Path to the source ``.tflite`` model.
        save_path: Destination path for the model with metadata populated.
        label_file_path: Optional path to a label file to associate with the
            output tensor (locale "en"). When ``None`` (the default), no label
            file is attached.
    """
    model_buffer = writer_utils.load_file(tflite_model_path)

    # Create general model information.
    general_md = metadata_info.GeneralMd(
        name="ImageSegmenter",
        version="v1",
        description=("Semantic image segmentation predicts whether each pixel "
                     "of an image is associated with a certain class."),
        author="Vaishak Nair",
        licenses="Apache License, Version 2.0")

    # Create input tensor information. The tensor type is read from the model
    # itself rather than hard-coded.
    # NOTE(review): "512 x 512" is not checked against the model's actual
    # input shape - confirm it matches the model being annotated.
    input_md = metadata_info.InputImageTensorMd(
        name="input image",
        description=("Input image to be segmented. The expected image is "
                     "512 x 512, with three channels (red, green, and blue) per "
                     "pixel. Each element in the tensor is a value between min and "
                     "max, where (per-channel) min is [0] and max is [255]."),
        norm_mean=[0.0],
        norm_std=[255.0],
        color_space_type=_metadata_fb.ColorSpaceType.RGB,
        tensor_type=writer_utils.get_input_tensor_types(model_buffer)[0])

    # Optional label file for the output tensor (replaces the previously
    # commented-out hard-coded label list).
    label_files = None
    if label_file_path is not None:
        label_files = [
            metadata_info.LabelFileMd(file_path=label_file_path, locale="en")
        ]

    # Create output tensor information.
    # NOTE(review): ClassificationTensorMd is kept for compatibility with the
    # previously generated metadata; the segmenter writer's own default is a
    # plain TensorMd - confirm which is intended.
    output_md = metadata_info.ClassificationTensorMd(
        name="probability",
        description=("Probabilities mask (0 - 1) for the input image. 1 being "
                     "completely foreground and 0 being completely background."),
        label_files=label_files,
        tensor_type=writer_utils.get_output_tensor_types(model_buffer)[0])

    ImageSegmenterWriter = image_segmenter.MetadataWriter
    writer = ImageSegmenterWriter.create_from_metadata_info(
        model_buffer, general_md, input_md, output_md)
    print(writer.get_metadata_json())

    # Populate the metadata into the model.
    writer_utils.save_file(writer.populate(), save_path)


if __name__ == "__main__":
    import argparse

    # Defaults preserve the original hard-coded paths so running the script
    # with no arguments behaves exactly as before.
    _DEFAULT_MODEL_PATH = "/home/vaishak/Downloads/tflite/u2netlite/40.tflite"
    _DEFAULT_SAVE_PATH = (
        "/home/vaishak/Downloads/tflite/u2netlite/40_with_metadata.tflite")

    parser = argparse.ArgumentParser(
        description="Populate image-segmenter metadata into a TFLite model.")
    parser.add_argument(
        "model", nargs="?", default=_DEFAULT_MODEL_PATH,
        help="path to the input .tflite model (default: %(default)s)")
    parser.add_argument(
        "output", nargs="?", default=_DEFAULT_SAVE_PATH,
        help="path to write the model with metadata (default: %(default)s)")
    args = parser.parse_args()

    save_model_with_metadata_advanced(args.model, args.output)
